<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - krr_regression_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*
    This is an example illustrating the use of the kernel ridge regression 
    object from the dlib C++ Library.

    This example will train on data from the sinc function.

*/</font>

<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>vector<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>svm.h<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;

<!-- Rendered C++ definition of sinc(double x): returns 1 when x == 0 and sin(x)/x otherwise. The empty name='sinc' anchor below marks this definition as a link target within the page. --><font color='#009900'>// Here is the sinc function we will be trying to learn with kernel ridge regression 
</font><font color='#0000FF'><u>double</u></font> <b><a name='sinc'></a>sinc</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> x<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>x <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <font color='#0000FF'>return</font> <font color='#BB00BB'>sin</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font><font color='#5555FF'>/</font>x;
<b>}</b>

<!-- Rendered C++ definition of main(): builds training samples from sinc() for x = -10 through 4 in steps of 1, trains a krr_trainer that uses a radial_basis_kernel, prints the true and predicted values at four test points, prints leave-one-out error statistics, then writes the trained decision function to saved_function.dat and reads it back into the same object. The empty name='main' anchor marks this definition as a link target within the page. --><font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Here we declare that our samples will be 1 dimensional column vectors.  
</font>    <font color='#0000FF'>typedef</font> matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>1</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font> sample_type;

    <font color='#009900'>// Now sample some points from the sinc() function
</font>    sample_type m;
    std::vector<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font> samples;
    std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> labels;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> x <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>10</font>; x <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>4</font>; x <font color='#5555FF'>+</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> x;
        samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>m<font face='Lucida Console'>)</font>;
        labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>sinc</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <b>}</b>

    <font color='#009900'>// Now we are making a typedef for the kind of kernel we want to use.  I picked the
</font>    <font color='#009900'>// radial basis kernel because it only has one parameter and generally gives good
</font>    <font color='#009900'>// results without much fiddling.
</font>    <font color='#0000FF'>typedef</font> radial_basis_kernel<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font> kernel_type;

    <font color='#009900'>// Here we declare an instance of the krr_trainer object.  This is the
</font>    <font color='#009900'>// object that we will later use to do the training.
</font>    krr_trainer<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> trainer;

    <font color='#009900'>// Here we set the kernel we want to use for training.   The radial_basis_kernel 
</font>    <font color='#009900'>// has a parameter called gamma that we need to determine.  As a rule of thumb, a good 
</font>    <font color='#009900'>// gamma to try is 1.0/(mean squared distance between your sample points).  So 
</font>    <font color='#009900'>// below we are using a similar value computed from at most 2000 randomly selected
</font>    <font color='#009900'>// samples.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> gamma <font color='#5555FF'>=</font> <font color='#979000'>3.0</font><font color='#5555FF'>/</font><font color='#BB00BB'>compute_mean_squared_distance</font><font face='Lucida Console'>(</font><font color='#BB00BB'>randomly_subsample</font><font face='Lucida Console'>(</font>samples, <font color='#979000'>2000</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;<!-- Review note: the numerator here is 3.0, whereas the rule of thumb quoted in the comment text just above is 1.0/(mean squared distance); the text only describes this as "a similar value". -->
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>using gamma of </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    trainer.<font color='#BB00BB'>set_kernel</font><font face='Lucida Console'>(</font><font color='#BB00BB'>kernel_type</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// now train a function based on our sample points
</font>    decision_function<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> test <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;

    <font color='#009900'>// now we output the value of the sinc function for a few test points as well as the 
</font>    <font color='#009900'>// value predicted by our regression.
</font>    <font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>2.5</font>; cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>sinc</font><font face='Lucida Console'>(</font><font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>   </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>test</font><font face='Lucida Console'>(</font>m<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>; cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>sinc</font><font face='Lucida Console'>(</font><font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>   </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>test</font><font face='Lucida Console'>(</font>m<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>4</font>;  cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>sinc</font><font face='Lucida Console'>(</font><font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>   </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>test</font><font face='Lucida Console'>(</font>m<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>5.0</font>; cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>sinc</font><font face='Lucida Console'>(</font><font color='#BB00BB'>m</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>   </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>test</font><font face='Lucida Console'>(</font>m<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;<!-- Review note: x = 5.0 lies outside the sampled interval of -10 through 4 used to build the training set in the loop above. -->

    <font color='#009900'>// The output is as follows:
</font>    <font color='#009900'>//using gamma of 0.075
</font>    <font color='#009900'>//    0.239389   0.239389
</font>    <font color='#009900'>//    0.998334   0.998362
</font>    <font color='#009900'>//    -0.189201   -0.189254
</font>    <font color='#009900'>//    -0.191785   -0.186618
</font>
    <font color='#009900'>// The first column is the true value of the sinc function and the second
</font>    <font color='#009900'>// column is the output from the krr estimate.  
</font>

    <font color='#009900'>// Note that the krr_trainer has the ability to tell us the leave-one-out predictions
</font>    <font color='#009900'>// for each sample.  
</font>    std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> loo_values;
    trainer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>samples, labels, loo_values<font face='Lucida Console'>)</font>;<!-- The value returned by this second train() call is not stored; only loo_values is used afterwards. -->
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>mean squared LOO error: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>mean_squared_error</font><font face='Lucida Console'>(</font>labels,loo_values<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>R^2 LOO value:          </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>r_squared</font><font face='Lucida Console'>(</font>labels,loo_values<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#009900'>// Which outputs the following:
</font>    <font color='#009900'>// mean squared LOO error: 8.29575e-07
</font>    <font color='#009900'>// R^2 LOO value:          0.999995
</font>




    <font color='#009900'>// Another thing that is worth knowing is that just about everything in dlib is serializable.
</font>    <font color='#009900'>// So for example, you can save the test object to disk and recall it later like so:
</font>    <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>saved_function.dat</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> test;

    <font color='#009900'>// Now let's open that file back up and load the function object it contains.
</font>    <font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>saved_function.dat</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> test;<!-- Review note: no error handling surrounds the serialize/deserialize calls in this example; how a missing or unwritable file is reported is not visible here and should be checked against the dlib serialization documentation. -->

<b>}</b>



</pre></body></html>